510679
@@ -17,7 +17,7 @@
  */
 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
 
-import java.util.HashSet;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
 
@@ -29,7 +29,6 @@
 import org.apache.calcite.rel.core.JoinRelType;
 import org.apache.calcite.rel.core.RelFactories;
 import org.apache.calcite.rel.core.RelFactories.FilterFactory;
-import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexUtil;
@@ -91,14 +90,12 @@
public void onMatch(RelOptRuleCall call) {
     } catch (CalciteSemanticException e) {
       return;
     }
-
-    Set<Integer> joinLeftKeyPositions = new HashSet<Integer>();
-    Set<Integer> joinRightKeyPositions = new HashSet<Integer>();
-    for (int i = 0; i < joinPredInfo.getEquiJoinPredicateElements().size(); i++) {
-      JoinLeafPredicateInfo joinLeafPredInfo = joinPredInfo.
-              getEquiJoinPredicateElements().get(i);
-      joinLeftKeyPositions.addAll(joinLeafPredInfo.getProjsFromLeftPartOfJoinKeysInChildSchema());
-      joinRightKeyPositions.addAll(joinLeafPredInfo.getProjsFromRightPartOfJoinKeysInChildSchema());
+
+    List<RexNode> leftJoinExprsList = new ArrayList<>();
+    List<RexNode> rightJoinExprsList = new ArrayList<>();
+    for (JoinLeafPredicateInfo joinLeafPredicateInfo : joinPredInfo.getEquiJoinPredicateElements()) {
+      leftJoinExprsList.addAll(joinLeafPredicateInfo.getJoinExprs(0));
+      rightJoinExprsList.addAll(joinLeafPredicateInfo.getJoinExprs(1));
     }
 
     // Build not null conditions
@@ -107,10 +104,10 @@
public void onMatch(RelOptRuleCall call) {
 
     Set<String> leftPushedPredicates = Sets.newHashSet(registry.getPushedPredicates(join, 0));
     final List<RexNode> newLeftConditions = getNotNullConditions(cluster,
-            rexBuilder, join.getLeft(), joinLeftKeyPositions, leftPushedPredicates);
+            rexBuilder, leftJoinExprsList, leftPushedPredicates);
     Set<String> rightPushedPredicates = Sets.newHashSet(registry.getPushedPredicates(join, 1));
     final List<RexNode> newRightConditions = getNotNullConditions(cluster,
-            rexBuilder, join.getRight(), joinRightKeyPositions, rightPushedPredicates);
+            rexBuilder, rightJoinExprsList, rightPushedPredicates);
 
     // Nothing will be added to the expression
     RexNode newLeftPredicate = RexUtil.composeConjunction(rexBuilder, newLeftConditions, false);
@@ -142,21 +139,25 @@
public void onMatch(RelOptRuleCall call) {
   }
 
+  /**
+   * Builds IS NOT NULL conditions for the join key expressions of one join input.
+   * Expressions whose type is not nullable are skipped, as are conditions whose
+   * digest is already in {@code pushedPredicates}; the digest of every returned
+   * condition is added to {@code pushedPredicates}.
+   */
   private static List<RexNode> getNotNullConditions(RelOptCluster cluster,
-          RexBuilder rexBuilder, RelNode input, Set<Integer> inputKeyPositions,
+          RexBuilder rexBuilder, List<RexNode> inputJoinExprs,
           Set<String> pushedPredicates) {
     final List<RexNode> newConditions = Lists.newArrayList();
-    for (int pos : inputKeyPositions) {
-      RelDataType keyType = input.getRowType().getFieldList().get(pos).getType();
+    for (RexNode rexNode : inputJoinExprs) {
       // Nothing to do if key cannot be null
-      if (!keyType.isNullable()) {
+      if (!rexNode.getType().isNullable()) {
         continue;
       }
-      RexNode cond = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
-              rexBuilder.makeInputRef(input, pos));
+      RexNode cond = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, rexNode);
       String digest = cond.toString();
       if (pushedPredicates.add(digest)) {
         newConditions.add(cond);
       }
     }
     return newConditions;
   }
